import torch
from tqdm import tqdm
import random
from llava_v15_utils import prompt_wrapper, generator
from torchvision.utils import save_image
import torch.nn.functional as F
import numpy as np
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
from matplotlib.widgets import MultiCursor
import seaborn as sns
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from llava_v15.model.builder import load_pretrained_model
from llava_v15.mm_utils import tokenizer_image_token
from llava_v15.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava_v15.conversation import conv_templates
from llava_v15_utils import generator
from torchvision import transforms
from PIL import Image
import copy
import os
import csv
import sys
        
def normalize(images):
    """Apply per-channel mean/std normalisation to a batch of images.

    The constants are the OpenAI CLIP image-preprocessing mean and std. They
    are broadcast over dim 1 (channels) of a rank-4 tensor, i.e. an
    (N, 3, H, W) batch.

    The constants are created on ``images.device``, so CPU tensors and
    tensors on any CUDA device work. Previously they were always placed on
    the current CUDA device.

    Args:
        images: rank-4 tensor with 3 channels on dim 1.

    Returns:
        ``(images - mean) / std``, channel-wise.
    """
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=images.device)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=images.device)
    return (images - mean[None, :, None, None]) / std[None, :, None, None]

def denormalize(images):
    """Invert the per-channel mean/std normalisation of a batch of images.

    This computes ``images * std + mean`` with the OpenAI CLIP
    preprocessing constants. They are broadcast over dim 1 (channels) of a
    rank-4 tensor, i.e. an (N, 3, H, W) batch.

    The constants are created on ``images.device``, so CPU tensors and
    tensors on any CUDA device work. Previously they were always placed on
    the current CUDA device.

    Args:
        images: rank-4 normalised tensor with 3 channels on dim 1.

    Returns:
        The de-normalised tensor, same shape as ``images``.
    """
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=images.device)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=images.device)
    return images * std[None, :, None, None] + mean[None, :, None, None]


class Attacker:
    """Optimise an additive image perturbation against a LLaVA-v1.5 model.

    Each update takes a signed-gradient step that *decreases* the loss the
    model returns for sampled target sentences (targeted PGD). The
    perturbation is kept inside an L-infinity ball of radius ``epsilon``, and
    the perturbed image is kept in [0, 1].
    """

    def __init__(self, args, model, tokenizer, targets, device='cuda:0', is_rtp=False, image_processor=None, run=None):
        """Store the attack configuration and freeze the model.

        Args:
            args: namespace read for ``ours``, ``th``, ``save_dir`` and,
                optionally, ``batch_size``.
            model: called with ``input_ids``, ``attention_mask``, ``labels``,
                ``images`` (and ``labels_2`` when ``args.ours``). It is put in
                eval mode and all its parameters are frozen.
            tokenizer: called on a list of strings; the result must expose
                ``input_ids`` (one list of token ids per string).
            targets: initial pool of target sentences. It is replaced from CSV
                in ``targeted_attack_B2H`` when ``args.ours`` is set.
            device: device for the prompt, labels, masks and perturbation.
            is_rtp: stored but not used by this class.
            image_processor: stored but not used by this class.
            run: optional experiment-tracker run; values are logged with
                ``run[key].append(value)``.
        """
        self.args = args
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.is_rtp = is_rtp

        self.targets = targets
        self.num_targets = len(targets)

        self.loss_buffer = []
        self.run = run

        # Freeze the model: only the image perturbation is optimised.
        self.model.eval()
        self.model.requires_grad_(False)

        self.image_processor = image_processor

    @staticmethod
    def _load_first_column(path):
        """Return the first field of every non-empty row of the CSV at ``path``."""
        # The csv module expects newline=""; blank rows are skipped instead of
        # raising IndexError on ``row[0]``.
        with open(path, "r", newline="", encoding="utf-8") as handle:
            return [row[0] for row in csv.reader(handle, delimiter=",") if row]

    def _build_empty_prompt(self):
        """Tokenise a ``llava_v1`` conversation whose user turn is only the image token."""
        conv = conv_templates["llava_v1"].copy()
        conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n")
        conv.append_message(conv.roles[1], None)
        prompt_text = conv.get_prompt()
        prompt_ids = tokenizer_image_token(prompt_text, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        return prompt_ids.unsqueeze(0).to(self.device)

    def _save_snapshot(self, x, adv_noise, class_n, t):
        """Write the image for iteration ``t`` to ``args.save_dir`` as PNG and ``.pt``.

        At ``t == 0`` the clean image is written. At other iterations the
        perturbed image is written.

        NOTE(review): the original code wrote the adversarial image first at
        t == 0 and then overwrote both files with the clean one. Only the
        final file contents are reproduced here. Confirm that losing the
        iteration-0 adversarial snapshot is intended.
        """
        print('######### Step = %d ##########' % t)
        if self.run is not None:
            # Log a plain float, not a tensor attached to the autograd graph.
            self.run['adv_value'].append(adv_noise.max().item())
        with torch.no_grad():
            snapshot = normalize(x) if t == 0 else normalize(x + adv_noise)
        save_image(denormalize(snapshot).cpu().squeeze(0), '%s/class%d_iter%d.png' % (self.args.save_dir, class_n, t))
        torch.save(snapshot, '%s/class%d_iter%d.pt' % (self.args.save_dir, class_n, t))

    def targeted_attack_B2H(self, img, batch_size=1, num_iter=2000, alpha=1/255, epsilon=128/255, ours=True, before=True):
        """Run the PGD attack on ``img``, logging losses and saving snapshots.

        Args:
            img: image tensor in the normalised space that ``denormalize``
                undoes. Presumably (1, 3, H, W) — TODO confirm against caller.
            batch_size: fallback only; ``self.args.batch_size`` takes
                precedence when present.
            num_iter: the loop performs ``num_iter + 1`` PGD updates
                (t = 0 .. num_iter).
            alpha: step size of each signed-gradient update.
            epsilon: L-infinity bound on the perturbation.
            ours, before: unused; kept for interface compatibility. Corpus
                loading and the loss variant are controlled by
                ``self.args.ours``.

        Returns:
            None. Losses accumulate in ``self.loss_buffer`` (and ``self.run``
            when given). Snapshots are saved every 1000 iterations.
        """
        batch_size = getattr(self.args, 'batch_size', batch_size)
        print('>>> batch_size:', batch_size)

        if self.args.ours:
            self.targets = self._load_first_column("harmful_corpus/derogatory_corpus.csv")
            self.targets_1 = self._load_first_column("harmful_corpus/benign_sentences.csv")
            self.targets_2 = self._load_first_column("harmful_corpus/harmful_words.csv")

        prompt_adv = self._build_empty_prompt()
        th = self.args.th

        # A single run; ``step`` still names the saved files and the loss key.
        for step in range(1):
            class_n = step
            print('######### step = %d ##########' % step)
            # Uniform random start in [-epsilon, epsilon].
            adv_noise = torch.rand_like(img).to(self.device) * 2 * epsilon - epsilon
            adv_noise.requires_grad_(True)

            batch_targets = random.sample(self.targets, batch_size)
            print("batch_targets >>>", batch_targets)
            # Clean image in pixel space; loop-invariant, so computed once.
            x = denormalize(img).clone().to(self.device)

            for t in tqdm(range(num_iter + 1)):
                x_adv = normalize(x + adv_noise)

                # Only len(batch_targets) is used by attack_loss; the sample
                # is kept so the random stream matches the original code.
                batch_targets = random.sample(self.targets, batch_size)
                target_loss, _ = self.attack_loss(prompt_adv, x_adv, batch_targets, th=th)
                target_loss.backward()

                # PGD descent step, then project back into the epsilon ball
                # and keep x + noise inside [0, 1].
                with torch.no_grad():
                    adv_noise.sub_(alpha * adv_noise.grad.sign()).clamp_(-epsilon, epsilon)
                    adv_noise.copy_((adv_noise + x).clamp(0, 1) - x)
                adv_noise.grad.zero_()
                self.model.zero_grad()

                loss_value = target_loss.item()
                self.loss_buffer.append(loss_value)
                if self.run is not None:
                    self.run[f'target_loss_step{step}'].append(loss_value)

                if t % 1000 == 0:
                    self._save_snapshot(x, adv_noise, class_n, t)

        return None

    def attack_loss(self, prompts, images, targets, rand_prob=0, th=0.5, many_shot_num=1):
        """Build a teacher-forced batch and return the model's ``(loss, logits)``.

        Every row of the batch is identical: ``prompts`` followed by the
        tokens of ``many_shot_num`` sentences joined by spaces. The sentences
        are sampled from ``self.targets`` when ``th < rand_prob`` and from
        ``self.targets_1`` otherwise. Prompt positions are masked with -100 in
        the labels.

        When ``self.args.ours`` is set, the model also receives ``labels_2``.
        If ``th < rand_prob`` it is the same tensor as ``labels``. Otherwise it
        is the prompt mask followed by a sequence built from (up to 50) words
        of ``self.targets_2``, cut to the target length.
        NOTE(review): ``labels_2`` is a non-standard keyword, presumably
        consumed by a patched model forward — confirm.

        Args:
            prompts: prompt token-id tensor. Its last dim is the prompt length
                and it is concatenated along dim 1, so presumably (1, L) —
                TODO confirm.
            images: image tensor. It is repeated ``len(targets)`` times along
                dim 0 and cast to half precision for the forward pass.
            targets: only ``len(targets)`` (the batch size) is used.
            rand_prob: ignored. It is redrawn uniformly in [0, 1) when
                ``self.args.ours`` is set, and fixed to 2 otherwise.
            th: threshold compared with ``rand_prob`` to pick the label pool.
            many_shot_num: number of sentences joined into one target string.

        Returns:
            ``(outputs.loss, outputs.logits)`` from ``self.model``.
        """
        rand_prob = random.random() if self.args.ours else 2

        batch_size = len(targets)
        context_length = [prompts.shape[-1]] * batch_size
        context_input_ids = [prompts] * batch_size
        images = images.repeat(batch_size, 1, 1, 1)

        assert len(context_input_ids) == len(targets), f"Unmathced batch size of prompts and targets {len(context_input_ids)} != {len(targets)}"

        same_labels = th < rand_prob  # label 1 = label 2
        pool = self.targets if same_labels else self.targets_1
        batch_targets = random.sample(pool, many_shot_num)
        bad_sentence = [" ".join(batch_targets)] * batch_size
        to_regress_tokens = [
            torch.as_tensor([ids], device=self.device)
            for ids in self.tokenizer(bad_sentence).input_ids
        ]

        long_tokens = None
        if self.args.ours:
            batch_targets_2 = random.sample(self.targets_2, min(50, len(self.targets_2)))
            bad_sentence_2 = ((" ".join(batch_targets_2) + " ") * 10).strip()
            harmful_ids = self.tokenizer([bad_sentence_2]).input_ids[0]
            # There is only ONE tokenised sequence: tile it 10x along the token
            # axis and share it across rows. Indexing it by row crashed for
            # batch_size > 1.
            long_tokens = torch.as_tensor([harmful_ids], device=self.device).repeat(1, 10)

        seq_tokens_length, labels, labels_2, input_ids = [], [], [], []
        for ctx_ids, ctx_len, item in zip(context_input_ids, context_length, to_regress_tokens):
            target_len = item.shape[1]
            seq_tokens_length.append(ctx_len + target_len)
            context_mask = torch.full([1, ctx_len], -100, dtype=item.dtype, device=item.device)
            labels.append(torch.cat([context_mask, item], dim=1))
            if long_tokens is not None:
                row_2 = long_tokens[:, :target_len]
                short_by = target_len - row_2.shape[1]
                if short_by > 0:
                    # Keep labels_2 the same length as labels.
                    filler = torch.full([1, short_by], -100, dtype=row_2.dtype, device=row_2.device)
                    row_2 = torch.cat([row_2, filler], dim=1)
                labels_2.append(torch.cat([context_mask, row_2], dim=1))
            input_ids.append(torch.cat([ctx_ids, item], dim=1))

        # Right-pad every row to the longest one: -100 for labels, token 0 for
        # input_ids, 0 in the attention mask.
        max_length = max(seq_tokens_length)
        attention_mask = []
        for i, seq_len in enumerate(seq_tokens_length):
            num_to_pad = max_length - seq_len
            label_pad = torch.full([1, num_to_pad], -100, dtype=torch.long, device=self.device)
            labels[i] = torch.cat([labels[i], label_pad], dim=1)
            if long_tokens is not None:
                labels_2[i] = torch.cat([labels_2[i], label_pad], dim=1)
            token_pad = torch.zeros([1, num_to_pad], dtype=input_ids[i].dtype, device=input_ids[i].device)
            input_ids[i] = torch.cat([input_ids[i], token_pad], dim=1)
            attention_mask.append(
                torch.tensor([[1] * seq_len + [0] * num_to_pad], dtype=torch.long, device=self.device)
            )

        model_kwargs = dict(
            input_ids=torch.cat(input_ids, dim=0),
            attention_mask=torch.cat(attention_mask, dim=0),
            return_dict=True,
            labels=torch.cat(labels, dim=0),
            images=images.half(),
            output_hidden_states=True,
        )
        if self.args.ours:
            model_kwargs["labels_2"] = model_kwargs["labels"] if same_labels else torch.cat(labels_2, dim=0)

        outputs = self.model(**model_kwargs)
        return outputs.loss, outputs.logits
    
        